from copy import deepcopy
import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from tensorboardX import SummaryWriter

from losses.square_loss import MSELoss

import argparse

# Run the whole experiment in float64: every floating-point tensor created
# below without an explicit dtype (randn, eye, ...) will be double precision.
# `torch.set_default_tensor_type` is deprecated since PyTorch 2.1;
# `torch.set_default_dtype` is the supported way to set the default dtype.
torch.set_default_dtype(torch.float64)


# Command-line interface of the experiment. Each entry is
# (flag, keyword arguments forwarded to `add_argument`); the flags are
# registered in the order listed.
_CLI_OPTIONS = (
    ('--p', dict(default=512, type=int, help='the dimension of features and protoypes')),
    ('--num-classes', dict(default=100, type=int, help='the number of classes')),
    ('--num-per-class', dict(default=10, type=int, help='the number of sample in each class')),
    ('--lamb', dict(default=0.1, type=float)),
    ('--gamma', dict(default=0.0, type=float)),
    ('--lr', dict(default=0.1, type=float, help='learning rate')),
    ('--exp', dict(default='exp', type=str)),
    ('--seed', dict(default=123, type=int, help='random seed')),
)

parser = argparse.ArgumentParser(description='Numerical Results')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parsed once at import time; the rest of the script reads `args` directly.
args = parser.parse_args()

# Restrict the process to GPU 0 and configure cuDNN.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Fixed: this flag was misspelled `determinstic`, which only created an unused
# attribute and left cuDNN's deterministic mode at its default.
# NOTE(review): PyTorch's reproducibility notes also recommend
# `benchmark = False` for reproducible algorithm selection; left unchanged here.
torch.backends.cudnn.deterministic = True
# Fall back to the CPU when no CUDA device is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed Python, NumPy and PyTorch (CUDA and CPU generators) with the same
# user-supplied value, in this fixed order.
for _seed_fn in (random.seed, np.random.seed, torch.cuda.manual_seed, torch.manual_seed):
    _seed_fn(args.seed)

# Problem size: C classes, N samples per class, p-dimensional vectors.
C, N, p = args.num_classes, args.num_per_class, args.p


class AveragedSampleMarginLossLoss(nn.Module):
    """Linear margin-style loss averaged over the batch.

    For every sample the loss is ``-z_y + gamma * sum_{j != y} z_j`` where
    ``z`` is the row of logits and ``y`` the sample's label; the returned
    value is the mean of that quantity over all samples.
    """

    def __init__(self, gamma=0.0):
        super().__init__()
        # Weight applied to the summed logits of the non-target classes.
        self.gamma = gamma

    def forward(self, logits, labels):
        """Return the batch-mean loss as a scalar tensor.

        Args:
            logits: 2-D tensor indexed as (sample, class).
            labels: integer class index of every sample.
        """
        num_classes = logits.shape[1]
        # 1 at the target class of each row, 0 elsewhere (float32 mask).
        target_mask = F.one_hot(labels, num_classes).to(dtype=torch.float32).to(logits.device)
        on_target = (logits * target_mask).sum(dim=-1)
        off_target = (logits * (1 - target_mask)).sum(dim=-1)
        per_sample = self.gamma * off_target - on_target
        return per_sample.mean()


def evaluate(out, labels):
    """Return the fraction of rows of `out` whose arg-max equals the label.

    Args:
        out: 2-D score tensor indexed as (sample, class).
        labels: 1-D tensor holding the class index of every sample.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    num_samples = labels.size(0)
    num_hits = out.argmax(dim=1).eq(labels).sum().item()
    return float(num_hits) / float(num_samples)


def get_margin(weight):
    """Return the smallest pairwise angle, in degrees, between rows of `weight`.

    Rows are L2-normalised first, so only their directions matter.
    """
    unit_rows = F.normalize(weight, dim=1)
    cosines = unit_rows @ unit_rows.transpose(0, 1)
    # Subtract 2 on the diagonal so a row's similarity with itself can never
    # be the maximum.
    cosines = cosines - 2 * torch.eye(unit_rows.size(0), device=weight.device)
    # Keep values strictly inside (-1, 1) before taking the arc-cosine.
    cosines = cosines.clamp(-1+1e-7, 1-1e-7)
    # The largest cosine corresponds to the smallest angle.
    smallest_angle = torch.acos(cosines.max())
    return smallest_angle.item() / math.pi * 180


# Labels cycle through the classes: sample i gets class i mod C, for C * N
# samples in total (int64 tensor, same as a LongTensor).
labels = torch.arange(C).repeat(N).to(device)

# Gaussian-initialised matrices, sampled on the CPU (H first, then W) and
# moved to the target device afterwards.
H = torch.randn(C * N, p).to(device)
W = torch.randn(C, p).to(device)
# Alternative initialisation of W, currently disabled:
# nn.init.kaiming_uniform_(W)

# Only H is optimised; W stays fixed for the whole run.
H.requires_grad_(True)
W.requires_grad_(False)

lr, gamma, lamb = args.lr, args.gamma, args.lamb

# Plain SGD on H; `lamb` is used as the optimizer's weight decay.
optimizer = torch.optim.SGD([dict(params=H, lr=lr)], weight_decay=lamb)
criterion = AveragedSampleMarginLossLoss(gamma=gamma)

# One log directory per hyper-parameter combination.
store_name = './log/pal/' + args.exp + '/dim={}, C={}, N={}, lambda={}, gamma={}, lr={}'.format(p, C, N, lamb, gamma, lr)
tf_writer = SummaryWriter(log_dir=store_name)

# Per-iteration histories, filled by the training loop.
acc_list, norm_h_list, error_list, gap_list, loss_list = [], [], [], [], []

# Snapshot of the initial H, used as the starting point of `dynamic`.
H0 = deepcopy(H)


def dynamic(H0, lam, eta, W, t):
    """Return the closed-form reference value of H after `t` steps.

    With ``C = W.size(0)`` classes, ``N = H0.size(0) // C`` samples per class
    and ``T`` the matrix W tiled N times along dim 0:

    * ``lam > 0``:  ``exp(-lam*eta*t) * H0 + (1 - exp(-lam*eta*t)) / (C*N*lam) * T``
    * otherwise:    ``H0 + eta * t / (C*N) * T``

    NOTE(review): presumably the continuous-time solution of the
    weight-decayed update for gamma = 0 -- confirm against the derivation.

    Args:
        H0: initial matrix, one row per sample (C * N rows).
        lam: weight-decay coefficient.
        eta: learning rate.
        W: matrix with one row per class (C rows).
        t: number of steps taken so far.

    Returns:
        Tensor with the same shape as `H0`.
    """
    # Derive the problem size from the arguments instead of relying on
    # module-level globals, so the function is self-contained.
    num_classes = W.size(0)
    num_per_class = H0.size(0) // num_classes
    tiled_w = W.repeat((num_per_class, 1))

    if lam > 0:
        decay = math.exp(-lam * eta * t)
        return decay * H0 + (1 - decay) / num_classes / num_per_class / lam * tiled_w
    # No weight decay: H grows linearly along the tiled rows of W.
    return H0 + eta * t / num_classes / num_per_class * tiled_w

epochs = 50000

# W is never updated (it is not handed to the optimizer), so the tiled matrix
# and its unit-Frobenius-norm version are loop invariants.
W_tiled = W.repeat((N, 1))
W_tiled_unit = W_tiled / torch.norm(W_tiled)

for ep in range(epochs):
    # One full-batch SGD step on H.
    out = F.linear(H, W)
    loss = criterion(out, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accuracy is measured on the logits computed before the update.
    acc = evaluate(out, labels)

    # Diagnostics only: no autograd graph is needed for them.
    with torch.no_grad():
        # Distance between the updated H and the closed-form trajectory
        # evaluated after ep + 1 steps.
        gap = torch.norm(H - dynamic(H0, lamb, lr, W, ep+1))
        norm_h = torch.norm(H)
        # Distance between the normalised H and the normalised tiled W.
        error = torch.norm(H / norm_h - W_tiled_unit)

    # Convert each metric to a Python float exactly once per iteration.
    loss_value = loss.item()
    gap_value = gap.item()
    error_value = error.item()
    norm_h_value = norm_h.item()

    loss_list.append(loss_value)
    acc_list.append(acc)
    gap_list.append(gap_value)
    error_list.append(error_value)
    norm_h_list.append(norm_h_value)

    tf_writer.add_scalar('acc', acc, ep)
    tf_writer.add_scalar('H', norm_h_value, ep)
    tf_writer.add_scalar('error', error_value, ep)
    tf_writer.add_scalar('gap', gap_value, ep)
    tf_writer.add_scalar('loss', loss_value, ep)
    if ep % 200 == 0:
        print('Iter {}: loss={:.4f}, acc={:.4f}, norm_h={:.4f}, gap={:.4f}, error={:.4f},'.format(ep, loss_value, acc, norm_h_value, gap_value, error_value))
    # The per-iteration torch.cuda.empty_cache() call was dropped: it does not
    # affect any computed value and only adds overhead to every step.


# Close the TensorBoard writer so that events still buffered by it are
# written to disk before the script ends (it was previously never closed).
tf_writer.close()

# Convert the per-iteration histories to arrays ...
acc_list = np.array(acc_list)
norm_h_list = np.array(norm_h_list)
error_list = np.array(error_list)
gap_list = np.array(gap_list)
loss_list = np.array(loss_list)

# ... and store each one as a .npy file inside the run's log directory.
np.save(store_name + '/loss.npy', loss_list)
np.save(store_name + '/acc.npy', acc_list)
np.save(store_name + '/gap.npy', gap_list)
np.save(store_name + '/error.npy', error_list)
np.save(store_name + '/norm_h.npy', norm_h_list)